# Source code for nlp_architect.models.bist.utils

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# pylint: disable=deprecated-module
import os
import subprocess
from collections import Counter

from nlp_architect.data.conll import ConllEntry
from nlp_architect.models.bist.eval.conllu.conll17_ud_eval import run_conllu_eval


# Things that were changed from the original:
# - Removed ConllEntry class, normalize()
# - Changed read_conll() and write_conll() input from file to path
# - Added run_eval(), get_options_dict() and is_conllu()
# - Reformatted code and variable names to conform with PEP8
# - Added legal header


def vocab(conll_path):
    """Collect word, POS and relation vocabularies from a CoNLL file.

    Only ``ConllEntry`` items of each sentence are counted; any other item
    found in a sentence (such as a raw string line) is ignored.

    Args:
        conll_path: Path of the CoNLL formatted file to scan.

    Returns:
        tuple: ``(words_count, word_index, pos_list, rel_list)`` where
        ``words_count`` is a ``Counter`` over the entries' ``norm`` values,
        ``word_index`` maps each counted value to its position in that
        counter, and ``pos_list`` / ``rel_list`` are the distinct ``pos`` and
        ``relation`` values that were seen.
    """
    word_freq = Counter()
    pos_freq = Counter()
    relation_freq = Counter()

    for sentence in read_conll(conll_path):
        # Filter once, then feed every counter from the same entry list.
        entries = [item for item in sentence if isinstance(item, ConllEntry)]
        word_freq.update(entry.norm for entry in entries)
        pos_freq.update(entry.pos for entry in entries)
        relation_freq.update(entry.relation for entry in entries)

    word_index = {word: index for index, word in enumerate(word_freq)}
    return word_freq, word_index, list(pos_freq), list(relation_freq)
def read_conll(path):
    """Yield CoNLL sentences read from a CoNLL formatted file.

    Every sentence is a list that starts with the artificial root entry and
    then holds one item per non-blank line of the sentence:

    * a ``ConllEntry`` for a regular token line, or
    * the stripped line itself (a ``str``) for a comment line (starting with
      ``#``) or a line whose ID column contains ``-`` or ``.``.

    Sentences are separated by blank lines. A sentence without any line is
    never yielded, so consecutive blank lines are harmless.

    Args:
        path: Path of the CoNLL formatted file to read.

    Yields:
        list: The items of the next sentence, root entry first.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which breaks non-ASCII text on non-UTF-8 systems.
    with open(path, 'r', encoding='utf-8') as conll_fp:
        # The same root instance opens every sentence.
        root = ConllEntry(0, '*root*', '*root*', 'ROOT-POS', 'ROOT-CPOS', '_', -1, 'rroot', '_',
                          '_')
        tokens = [root]
        for line in conll_fp:
            stripped_line = line.strip()
            if not stripped_line:
                # A blank line closes the current sentence, if it has content.
                if len(tokens) > 1:
                    yield tokens
                tokens = [root]
                continue

            tok = stripped_line.split('\t')
            # Test the stripped line so an indented comment is still treated
            # as a comment instead of failing in int(tok[0]) below.
            if stripped_line.startswith('#') or '-' in tok[0] or '.' in tok[0]:
                # Kept verbatim as a string rather than parsed into an entry.
                # noinspection PyTypeChecker
                tokens.append(stripped_line)
            else:
                # Columns 3 and 4 are passed in swapped order (tok[4] before
                # tok[3]); a head of '_' is stored as -1.
                tokens.append(
                    ConllEntry(int(tok[0]), tok[1], tok[2], tok[4], tok[3], tok[5],
                               int(tok[6]) if tok[6] != '_' else -1, tok[7], tok[8], tok[9]))

        # The last sentence may not be followed by a blank line.
        if len(tokens) > 1:
            yield tokens
def write_conll(path, conll_gen):
    """Write CoNLL sentences to a CoNLL formatted file.

    The first item of every sentence is skipped (``read_conll`` puts the
    artificial root entry there). Each remaining item is written as
    ``str(item)`` on its own line, and every sentence is followed by one
    blank line.

    Args:
        path: Path of the file to create or overwrite.
        conll_gen: Iterable of sentences, each a sequence of items.
    """
    # Encode explicitly as UTF-8 so non-ASCII tokens are written the same way
    # on every platform instead of depending on the locale encoding.
    with open(path, 'w', encoding='utf-8') as out_fp:
        for sentence in conll_gen:
            for entry in sentence[1:]:
                out_fp.write(str(entry) + '\n')
            out_fp.write('\n')
def run_eval(gold, test):
    """Evaluate a set of predictions using the appropriate script.

    When ``gold`` has a ``.conllu`` extension the evaluation is delegated to
    ``run_conllu_eval``. Otherwise the ``eval/eval.pl`` script located next to
    this module is run with ``perl`` and its standard output is saved to a
    file named after ``test`` with the extension replaced by ``_eval.txt``.

    Args:
        gold: Path of the gold-standard annotation file.
        test: Path of the predictions file.
    """
    if is_conllu(gold):
        run_conllu_eval(gold_file=gold, test_file=test)
        return

    eval_script = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'eval', 'eval.pl')
    # splitext only considers the last path component, so this also works
    # when ``test`` has no extension or when only a directory name has a dot.
    report_path = os.path.splitext(test)[0] + '_eval.txt'
    with open(report_path, 'w') as out_file:
        # The script's exit status is not checked.
        subprocess.run(['perl', eval_script, '-g', gold, '-s', test], stdout=out_file)
def is_conllu(path):
    """Tell whether ``path`` names a CoNLL-U file.

    Only the file extension is inspected, case-insensitively.

    Args:
        path: File path as a string.

    Returns:
        bool: True when the extension is ``.conllu``.
    """
    _, extension = os.path.splitext(path.lower())
    return extension == '.conllu'
def get_options_dict(activation, lstm_dims, lstm_layers, pos_dims):
    """Build the dictionary holding all parser options.

    Args:
        activation: Stored under ``'activation'``.
        lstm_dims: Stored under ``'lstm_dims'``.
        lstm_layers: Stored under ``'lstm_layers'``.
        pos_dims: Stored under ``'pembedding_dims'``.

    Returns:
        dict: The four caller-supplied options followed by the fixed ones.
    """
    # Options chosen by the caller.
    options = {
        'activation': activation,
        'lstm_dims': lstm_dims,
        'lstm_layers': lstm_layers,
        'pembedding_dims': pos_dims,
    }
    # Options that always take the same value.
    options.update({
        'wembedding_dims': 100,
        'rembedding_dims': 25,
        'hidden_units': 100,
        'hidden2_units': 0,
        'learning_rate': 0.1,
        'blstmFlag': True,
        'labelsFlag': True,
        'bibiFlag': True,
        'costaugFlag': True,
        'seed': 0,
        'mem': 0,
    })
    return options